{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "import yaml\n",
    "from torchvision import transforms, datasets\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "from sklearn import preprocessing\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as dsets\n",
    "from torch.autograd import Variable\n",
    "from torch.nn import Parameter\n",
    "from torch import Tensor\n",
    "import torch.nn.functional as F\n",
    "import torchmetrics\n",
    "\n",
    "from Rosseta.resnet_base_network import ResNet18\n",
    "from util.loader_1 import DataTraffic\n",
    "import joblib\n",
    "from Rosseta.lstm import LSTMCell,LSTMModel\n",
    "\n",
    "input_dim = 32\n",
    "seq_dim = 16\n",
    "device = 'cpu'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Helper Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_features_from_encoder(encoder, loader):\n",
    "    \n",
    "    x_train = []\n",
    "    y_train = []\n",
    "\n",
    "    # get the features from the pre-trained model\n",
    "    for i, (x, y) in enumerate(loader):\n",
    "        with torch.no_grad():\n",
    "            feature_vector = encoder(x[0])\n",
    "            x_train.extend(feature_vector)\n",
    "            y_train.extend(y.numpy())\n",
    "\n",
    "    x_train = torch.stack(x_train)\n",
    "    y_train = torch.tensor(y_train)\n",
    "    return x_train, y_train\n",
    "\n",
    "def generate_encoder():\n",
    "    device = 'cpu'\n",
    "    config = yaml.load(open(\"Rosseta/model/checkpoints/config.yaml\", \"r\"), Loader=yaml.FullLoader)\n",
    "    encoder = ResNet18(**config['network'])\n",
    "\n",
    "    load_params = torch.load(os.path.join('/home/mininet/experiment/Cross-network/Rosseta/test/Rosseta/model/checkpoints/model.pth'),\n",
    "                            map_location=torch.device(torch.device(device)))\n",
    "\n",
    "    if 'online_network_state_dict' in load_params:\n",
    "        encoder.load_state_dict(load_params['online_network_state_dict'])\n",
    "        print(\"Parameters successfully loaded.\")\n",
    "    encoder = torch.nn.Sequential(*list(encoder.children())[:-1])    \n",
    "    encoder = encoder.to(device)\n",
    "    encoder.eval()\n",
    "    return encoder\n",
    "\n",
    "def create_data_loaders_from_arrays(X_test, y_test):\n",
    "    test = torch.utils.data.TensorDataset(X_test, y_test)\n",
    "    test_loader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=True)\n",
    "    return test_loader\n",
    "\n",
    "def generate_testFeature(file_name, encoder):\n",
    "    test_dataset = DataTraffic(file_name, dataset = 'test')\n",
    "    test_loader = DataLoader(test_dataset, batch_size=128,\n",
    "                          num_workers=0, drop_last=False, shuffle=False)\n",
    "\n",
    "    x_test, y_test = get_features_from_encoder(encoder, test_loader)\n",
    "    x_test = torch.mean(x_test, dim=[2, 3])\n",
    "    scaler = pickle.load(open('util/sc.pkl','rb'))\n",
    "    x_test = scaler.transform(x_test).astype(np.float32)\n",
    "    x_test = np.pad(x_test, ((0, 0), (0, 512 - len(x_test[0]))), 'constant').reshape(-1, 1, 32, 16)\n",
    "    test = torch.utils.data.TensorDataset(torch.from_numpy(x_test), y_test)\n",
    "    test_loader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=False)\n",
    "    return test_loader\n",
    "\n",
    "def runmodel(test_loader, model):\n",
    "    accuracy = torchmetrics.Accuracy()\n",
    "    recall = torchmetrics.Recall(average='macro', num_classes=10)\n",
    "    precision = torchmetrics.Precision(average='macro', num_classes=10)\n",
    "    f1score = torchmetrics.F1(average='macro', num_classes=10)\n",
    "    eval_loss = 0\n",
    "    eval_acc = 0\n",
    "\n",
    "    for images, labels in test_loader:\n",
    "        #######################\n",
    "        #  USE GPU FOR MODEL  #\n",
    "        #######################\n",
    "        if torch.cuda.is_available():\n",
    "            images = Variable(images.view(-1, seq_dim, input_dim).cuda())\n",
    "        else:\n",
    "            images = Variable(images.view(-1 , seq_dim, input_dim))\n",
    "        \n",
    "        # Forward pass only to get logits/output\n",
    "        outputs = model(images)\n",
    "        \n",
    "        # Get predictions from the maximum value\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        \n",
    "        \n",
    "        labels = labels.cpu()\n",
    "        predicted = predicted.cpu()\n",
    "        accuracy(predicted, labels)\n",
    "        recall(predicted, labels)\n",
    "        precision(predicted, labels)\n",
    "        f1score(predicted, labels)\n",
    "\n",
    "\n",
    "    acc = accuracy.compute().data.detach()\n",
    "    print('Accuracy:',acc)\n",
    "    rec = recall.compute().data.detach()\n",
    "    prec = precision.compute().data.detach()\n",
    "    f1sc = f1score.compute().data.detach()\n",
    "    print(f\" Recall: {rec}, precision: {prec}, f1score: {f1sc}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "testFiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters successfully loaded.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1805: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
      "  warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n",
      "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1794: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
      "  warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: tensor(0.9476)\n",
      " Recall: 0.9519410133361816, precision: 0.9388858079910278, f1score: 0.9444727897644043\n"
     ]
    }
   ],
   "source": [
    "\n",
    "dir = 'data/'\n",
    "testFile = dir + 'test-cn-bin.csv'\n",
    "testFile1 = dir + 'test-kr-bin.csv'\n",
    "testFile2 = dir + 'test-us-bin.csv'\n",
    "testFile3 = dir + 'test-wifi-bin.csv'\n",
    "testFile4 = dir + 'test-4G-bin.csv'\n",
    "testFile5 = dir + 'test-3G-bin.csv'\n",
    "\n",
    "encoder = generate_encoder()\n",
    "test_loader = generate_testFeature(testFile5, encoder)\n",
    "model = torch.load('Rosseta/lstm-Rosseta.pth')\n",
    "runmodel(test_loader, model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
